import copy
import math
import time
# import matplotlib.pyplot as plt
# import numpy as np
# import spline
from cnrp_experiment.exp_cnrp import TestCNRP_BlockInfo
from cnrp_experiment.exp_config import *
from cnrp_experiment.exp_block import *


class TestBlockChain(object):
    """Simulated blockchain that drives the CNRP reputation experiment.

    All chain state lives on ``GlobalDict`` (``chain_dict``, ``cur_block``,
    ``last_height``, ``last_tx_id``, ``node_dict``, ``time_consume``); this
    class only holds the operations that read and update that state.
    """

    # For a transaction of the given type, the execute-queue name that its
    # tx id is appended to on the node that issued the previous transaction.
    _NEXT_QUEUE = {'request': 'share', 'share': 'evaluate'}

    @staticmethod
    def init_cur_block():
        """Start a new, empty pending block at height ``last_height + 1``."""
        GlobalDict.cur_block = TestBlock(TestBlockHeader(GlobalDict.last_height + 1))

    def update_execute_queue(self, block: TestBlock):
        """Update the nodes' executable queues from the transactions of *block*.

        - ``upload``: the tx id is appended to every node's ``'request'``
          queue and then removed again from the uploader's queue.
        - ``request`` / ``share``: the tx id is appended to the ``'share'`` /
          ``'evaluate'`` queue of the node that issued the previous tx.
        - ``evaluate``: nothing is scheduled.

        Transactions of any other type are skipped (they no longer abort
        processing of the rest of the block).
        """
        for tx in block.transactions:
            tx_type = tx.tx_type
            if tx_type == 'upload':
                for node in GlobalDict.node_dict.values():
                    node.execute_queue_append('request', tx.tx_id)
                GlobalDict.node_dict[tx.source_node_id].execute_queue_remove('request', tx.tx_id)
            elif tx_type in self._NEXT_QUEUE:
                prev_tx = self.get_tx_by_id(tx.prev_tx_id)
                GlobalDict.node_dict[prev_tx.source_node_id].execute_queue_append(
                    self._NEXT_QUEUE[tx_type], tx.tx_id)
            # 'evaluate' and unknown types schedule nothing.

    def add_tx(self, tx: TestTransaction):
        """Append *tx* to the pending block and seal the block once it is full.

        Before appending, the ids are derived: uploads generate their own data
        id; requests take the data id of their previous transaction and use
        ``"<data_id>-<source_node_id>"`` as business id; every other type
        inherits both ids from its previous transaction.

        When the pending block becomes full, its CNRP values are computed (the
        elapsed time is recorded in ``GlobalDict.time_consume``). The block is
        then appended to the chain, the execute queues are updated, and a new
        pending block is started.
        """
        if tx.tx_type == 'upload':
            tx.set_data_id_upload()
        elif tx.tx_type == 'request':
            tx.data_id = self.get_tx_by_id(tx.prev_tx_id).data_id
            tx.business_id = tx.data_id + '-' + str(tx.source_node_id)
        else:
            prev_tx = self.get_tx_by_id(tx.prev_tx_id)
            tx.data_id = prev_tx.data_id
            tx.business_id = prev_tx.business_id

        GlobalDict.cur_block.add_tx(tx)
        GlobalDict.last_tx_id += 1

        if not GlobalDict.cur_block.is_full:
            return

        new_height = GlobalDict.last_height + 1

        # Key step: fill in the block's reputation (CNRP) values, timed with a
        # monotonic clock. This must run before last_height is advanced,
        # because update_block_rp reads the previous block via last_height.
        start = time.perf_counter()
        block = self.update_block_rp(GlobalDict.cur_block)
        elapsed = round(time.perf_counter() - start, 4)
        GlobalDict.time_consume['RP_block'][new_height] = elapsed
        GlobalDict.time_consume['RP_total'] += elapsed

        GlobalDict.chain_dict[new_height] = block
        GlobalDict.last_height = new_height
        self.update_execute_queue(block)
        self.init_cur_block()

    @staticmethod
    def create_genesis_block():
        """Store the genesis block (height -1, one 'upload' tx) in the chain."""
        block = TestBlock(TestBlockHeader(-1))
        block.transactions.append(TestTransaction(-1, 'upload', -1))
        GlobalDict.chain_dict[-1] = block

    @staticmethod
    def get_block_by_height(height: int):
        """Return the sealed block at *height* (raises KeyError if absent)."""
        return GlobalDict.chain_dict[height]

    @staticmethod
    def get_tx_by_id(tx_id: int):
        """Return the transaction with global id *tx_id*.

        The lookup assumes ids are assigned sequentially, ``BLOCK_TX_NUM`` per
        block. The id therefore gives both the block height and the index
        inside that block. Ids beyond the last sealed block are looked up in
        the pending block.
        """
        height, index = divmod(tx_id, BLOCK_TX_NUM)
        if height > GlobalDict.last_height:
            return GlobalDict.cur_block.transactions[index]
        return GlobalDict.chain_dict[height].transactions[index]

    def update_block_rp(self, block: TestBlock):
        """Compute and attach the CNRP reputation values of *block*.

        The new cnrp is derived from the cnrp of the last sealed block
        (``GlobalDict.chain_dict[GlobalDict.last_height]``):

        1. count per-node activity (``B_num``) and per-(target, evaluator)
           positive/negative evaluations (``P_N_num``) in *block*;
        2. update the windowed activity counts and recompute B for every node
           whose count changed;
        3. incrementally update the E fraction of every evaluated node;
        4. set R = B * E for every node that has a B or an E value.

        A block may contain any number of transactions.

        :return: *block*, with ``block.block_header.cnrp`` replaced.
        """

        # Round intermediate values to a fixed precision (keeps numbers short
        # and the computation fast).
        def _round(_x):
            return round(_x, 6)

        # Activity score B from a node's activity count inside the window.
        def _compute_B(_num):
            return _round(1 - math.exp(-CNRP_B_positive * _num))

        # S = P / (P + b * N); 0 when there are no positive evaluations.
        def _compute_S(_pas_neg):
            _S_top = _pas_neg[0]
            if _S_top <= 0:
                return 0
            _S_btm = _pas_neg[0] + CNRP_b_punish * _pas_neg[1]
            return _S_top / _S_btm

        # One evaluator's [numerator, denominator] contribution to E:
        # weight = R_evaluator^2 * (1 - e^-(P + N)); contribution = [weight * S, weight].
        def _compute_E_factor(_pas_neg_num_list, _R_source_node):
            _S = _compute_S(_pas_neg_num_list)
            if _S <= 0:
                _S = 0
            _speed_factor_e = _round(1 - math.exp(-sum(_pas_neg_num_list)))
            _weight = _round((_R_source_node ** 2) * _speed_factor_e)
            return [_round(_weight * _S), _round(_weight)]

        # Virtual-evaluator gain: each evaluator missing below
        # CNRP_threshold_rp_num counts as one with R = CNRP_R_init, a single
        # evaluation and S = CNRP_S_virtual_node.
        def _compute_E_gain(_cur_evaluate_node):
            _gain_num = CNRP_threshold_rp_num - _cur_evaluate_node
            if _gain_num <= 0:
                return [0, 0]
            _speed_e = _round(1 - math.exp(-1))
            _fraction_btm_gain = _round((CNRP_R_init ** 2) * _speed_e * _gain_num)
            return [_fraction_btm_gain * CNRP_S_virtual_node, _fraction_btm_gain]

        # Union of the keys of two dicts, in first-seen order.
        def _get_keys(_dict1, _dict2):
            return list(dict.fromkeys([*_dict1, *_dict2]))

        # Count activity and evaluations in one block.
        def _count(_block: TestBlock):
            _B_num = {}
            _P_N_num = {}
            for _tx in _block.transactions:
                if _tx.tx_type not in ('upload', 'share', 'evaluate'):
                    continue
                _source_node_id = _tx.source_node_id
                _B_num[_source_node_id] = _B_num.get(_source_node_id, 0) + 1
                if _tx.tx_type == 'evaluate':
                    _target_node_id = self.get_tx_by_id(_tx.prev_tx_id).source_node_id
                    _counts = _P_N_num.setdefault(_target_node_id, {}).setdefault(_source_node_id, [0, 0])
                    # [positive, negative]; a score below 0.5 is negative.
                    _counts[1 if _tx.data_score < 0.5 else 0] += 1

            _block_info = TestCNRP_BlockInfo()
            _block_info.B_num = _B_num
            _block_info.P_N_num = _P_N_num
            return _block_info

        # main:
        height = block.block_header.height
        prev_cnrp = GlobalDict.chain_dict[GlobalDict.last_height].block_header.cnrp
        cnrp = copy.deepcopy(prev_cnrp)
        temp = cnrp.compute_temp

        cnrp.block_info = _count(block)

        # Sliding time window (affects B only). The window holds the last
        # CNRP_WINDOW_LENGTH blocks, so the block at height - CNRP_WINDOW_LENGTH
        # leaves it now. '>=' makes block 0 leave the window as well. A
        # non-positive length disables the window.
        changed_nodes = []
        if 0 < CNRP_WINDOW_LENGTH <= height:
            compare_cnrp = GlobalDict.chain_dict[height - CNRP_WINDOW_LENGTH].block_header.cnrp
            for node, num in compare_cnrp.block_info.B_num.items():
                temp.B_num[node] -= num
                changed_nodes.append(node)

        # Add this block's activity counts.
        for node, num in cnrp.block_info.B_num.items():
            temp.B_num[node] = temp.B_num.get(node, 0) + num
            changed_nodes.append(node)

        # Recompute B for every node whose windowed count changed. This
        # includes nodes that were active only in the block that left the window.
        for node in dict.fromkeys(changed_nodes):
            temp.B[node] = _compute_B(temp.B_num[node])

        # Incrementally update P/N counts and the E fraction of each node
        # evaluated in this block: every evaluator's previous contribution is
        # replaced by its new one.
        for target_node, record_dict in cnrp.block_info.P_N_num.items():
            prev_record_dict = prev_cnrp.compute_temp.P_N_num.get(target_node, {})
            # The virtual gain depends only on the number of distinct
            # evaluators of target_node so far.
            E_fraction_gain = _compute_E_gain(len(_get_keys(record_dict, prev_record_dict)))

            for source_node, num_list in record_dict.items():
                R_source_node = prev_cnrp.value.get(source_node, CNRP_R_init)

                # Stored record layout: [P, N, E_top contribution, E_btm contribution].
                prev_record = prev_record_dict.get(source_node, [0, 0, 0, 0])
                # Read the running fraction from the cnrp being built, not from
                # prev_cnrp. Otherwise, when several evaluators rate the same
                # target in one block, only the last evaluator's update survives.
                prev_fraction = temp.E_fraction.get(target_node, [0, 0])

                cur_P_N_list = [prev_record[0] + num_list[0],
                                prev_record[1] + num_list[1]]
                old_factor = prev_record[2:]
                new_factor = _compute_E_factor(cur_P_N_list, R_source_node)

                E_fraction_top_core = prev_fraction[0] - old_factor[0] + new_factor[0]
                E_fraction_btm_core = prev_fraction[1] - old_factor[1] + new_factor[1]
                E_fraction_top = E_fraction_top_core + E_fraction_gain[0]
                E_fraction_btm = E_fraction_btm_core + E_fraction_gain[1]

                E = CNRP_E_init
                if E_fraction_btm != 0:
                    E = _round(E_fraction_top / E_fraction_btm)

                temp.P_N_num.setdefault(target_node, {})[source_node] = cur_P_N_list + new_factor
                # The gain is not stored; it is recomputed on every update.
                temp.E_fraction[target_node] = [E_fraction_top_core, E_fraction_btm_core]
                temp.E[target_node] = E

        # R = B * E for every node with a B or an E value. A node without a B
        # uses 0 (the value of B at activity count 0) instead of crashing on None.
        for node in _get_keys(temp.B, temp.E):
            B = temp.B.get(node, 0)
            E = temp.E.get(node, CNRP_E_init)
            cnrp.value[node] = _round(B * E)

        block.block_header.cnrp = cnrp
        return block

    @staticmethod
    def get_cur_cnrp(sort_type='descend'):
        """Return the latest sealed block's reputation values, sorted.

        :param sort_type: ``'descend'`` (default) for highest reputation
            first; any other value (e.g. ``'ascend'``) sorts ascending.
        :return: list of ``(node_id, rp)`` tuples ordered by rp, then node id.
        """
        cnrp = copy.deepcopy(GlobalDict.chain_dict[GlobalDict.last_height].block_header.cnrp.value)
        return sorted(cnrp.items(), key=lambda kv: (kv[1], kv[0]), reverse=(sort_type == 'descend'))

